import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import os.path as osp  
import numpy as np 
import random
from pytorchcv.model_provider import get_model as ptcv_get_model
from dataset.oxford_pet import Databasket, CUB
from fastai.vision.all import *

import torchvision
import torchvision.transforms as transforms
from dataset import cifar

import os
import argparse

import torchvision.models as models
from utils import progress_bar

class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes K.
        epsilon (float): smoothing weight.
        use_gpu (bool): kept for backward compatibility only. The smoothed
            target tensor is always built on the same device as ``inputs``,
            so this flag no longer forces a ``.cuda()`` transfer (which used
            to crash on CPU-only runs).
        reduction (bool): if True, return the batch mean; otherwise return
            the per-sample loss vector.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.reduction = reduction
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: logits (before softmax) with shape (batch_size, num_classes)
            targets: integer class labels with shape (batch_size,)

        Returns:
            A scalar mean loss if ``reduction`` is True, else a tensor of shape
            (batch_size,).
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot matrix directly on the logits' device/dtype so the
        # loss works identically for CPU and GPU inputs.
        one_hot = torch.zeros_like(log_probs).scatter_(
            1, targets.unsqueeze(1).to(log_probs.device), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (-smoothed * log_probs).sum(dim=1)
        if self.reduction:
            return loss.mean()
        return loss

# Command-line configuration for fine-tuning a source-pretrained network on a
# target dataset.
parser = argparse.ArgumentParser(description='PyTorch transfer-learning (fine-tuning) training')
parser.add_argument('--lr', default=1e-2, type=float, help='learning rate')
parser.add_argument('--net', default='resnet50', type=str, help='architecture')
parser.add_argument('--bs', default=64, type=int, help='batch size')
parser.add_argument('--epoch', '-e', default=20, type=int, help='number of training epochs')
parser.add_argument('--input_size', '-i', default=224, type=int,
                    help='side length of the image fed to the network')
parser.add_argument('--crop_size', default=256, type=int,
                    help='side length training images are resized to before random cropping')
parser.add_argument('--workers', '-w', default=4, type=int, help='number of DataLoader workers')

parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--multigpu', '-m', action='store_true', help='wrap the model in DataParallel')
parser.add_argument('--test', '-t', action='store_true', help='only evaluate the saved checkpoint')
parser.add_argument('--gpu', '-g', default='0', type=str, help='gpu id')
parser.add_argument('--tune_head', action='store_true',
                    help='freeze the backbone and train only the classification head')
parser.add_argument('--resplit', action='store_true',
                    help='passed as `resplit` to the CIFAR / Oxford-Pets dataset loaders')
parser.add_argument('--ls', action='store_true', help='use label-smoothing cross entropy')
parser.add_argument('--debug', action='store_true', help='disable log file and checkpoint saving')

parser.add_argument('--tgt', default='cifar100', type=str, help='target set')
parser.add_argument('--src', default='imagenet', type=str, help='source set')
parser.add_argument('--ssl', default=None, type=str, help='source ssl method')
args = parser.parse_args()

# Select visible GPUs before any CUDA work happens.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
SEED = args.seed
# Seed every RNG source used by the script (torch CPU, torch CUDA, numpy, stdlib).
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

ckpt_src_dir = f'checkpoint/src/{args.src}/{args.net}'
# Target checkpoints go under checkpoint_observ/<src>2<tgt>[_<ssl>]/<net>.
_ssl_suffix = '' if args.ssl is None else f'_{args.ssl}'
ckpt_tgt_dir = f'checkpoint_observ/{args.src}2{args.tgt}{_ssl_suffix}/{args.net}'
ckpt_tgt = f'{ckpt_tgt_dir}/ckpt.pth'
# Data
print('==> Preparing data..')
# NOTE(review): these are the widely quoted CIFAR-10 channel statistics, yet
# they are applied to every target dataset and every (ImageNet-pretrained)
# backbone — confirm whether ImageNet statistics were intended.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)

# Training: resize to crop_size x crop_size, take a random input_size crop, random h-flip.
transform_train = transforms.Compose([
    transforms.Resize([args.crop_size, args.crop_size]),
    transforms.RandomCrop([args.input_size, args.input_size]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Evaluation: deterministic resize straight to input_size x input_size.
transform_test = transforms.Compose([
    transforms.Resize([args.input_size, args.input_size]),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])
# Each branch sets NUM_CLS_TGT and builds `trainset` / `testset`; both
# DataLoaders are created once afterwards with identical settings.
if args.tgt == 'cifar10':
    NUM_CLS_TGT = 10
    trainset = cifar.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train, resplit=args.resplit)
    testset = cifar.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test, resplit=args.resplit)
elif args.tgt == 'cifar100':
    NUM_CLS_TGT = 100
    trainset = cifar.CIFAR100(
        root='./data', train=True, download=False, transform=transform_train, resplit=args.resplit)
    testset = cifar.CIFAR100(
        root='./data', train=False, download=False, transform=transform_test, resplit=args.resplit)
elif args.tgt == 'imagenette':
    NUM_CLS_TGT = 10
    train_root = './data/imagenette/train'
    test_root = './data/imagenette/val'
    trainset = torchvision.datasets.ImageFolder(train_root, transform=transform_train)
    # Evaluate on the held-out val folder (this previously re-read train_root,
    # so "test" accuracy was measured on training images).
    testset = torchvision.datasets.ImageFolder(test_root, transform=transform_test)
elif args.tgt == 'oxfordpets':
    NUM_CLS_TGT = 37
    databasket = Databasket(train_transforms=transform_train, val_transforms=transform_test,
                            resplit=args.resplit)
    trainset = databasket.train_ds
    testset = databasket.val_ds
elif args.tgt == 'oxfordflowers':
    NUM_CLS_TGT = 102
    trainset = torchvision.datasets.Flowers102(root='./data/', split='train', transform=transform_train, download=False)
    testset = torchvision.datasets.Flowers102(root='./data/', split='val', transform=transform_test, download=False)
elif args.tgt == 'CUB':
    NUM_CLS_TGT = 200
    trainset = CUB(root='./data/CUB', is_train=True, transform=transform_train)
    testset = CUB(root='./data/CUB', is_train=False, transform=transform_test)
elif args.tgt == 'DTD':
    NUM_CLS_TGT = 47
    # One root for both splits; the val split previously pointed at a relative
    # './data/yuhe.ding/...' path and would download a second copy there.
    dtd_root = '/data/yuhe.ding/DATA/DTD'
    trainset = torchvision.datasets.DTD(root=dtd_root, split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.DTD(root=dtd_root, split='val', transform=transform_test, download=True)
elif args.tgt == 'food101':
    NUM_CLS_TGT = 101
    trainset = torchvision.datasets.Food101(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.Food101(root='./data/', split='test', transform=transform_test, download=True)
elif args.tgt == 'country211':
    NUM_CLS_TGT = 211
    trainset = torchvision.datasets.Country211(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.Country211(root='./data/', split='valid', transform=transform_test, download=True)
elif args.tgt == 'place365':
    NUM_CLS_TGT = 365
    trainset = torchvision.datasets.Places365(root='./data/', split='train-standard', transform=transform_train, download=True)
    testset = torchvision.datasets.Places365(root='./data/', split='val', transform=transform_test, download=True)
elif args.tgt == 'stanfordcars':
    NUM_CLS_TGT = 196
    trainset = torchvision.datasets.StanfordCars(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.StanfordCars(root='./data/', split='test', transform=transform_test, download=True)
elif args.tgt == 'stl10':
    NUM_CLS_TGT = 10
    trainset = torchvision.datasets.STL10(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.STL10(root='./data/', split='test', transform=transform_test, download=True)
elif args.tgt == 'svhn':
    NUM_CLS_TGT = 10
    trainset = torchvision.datasets.SVHN(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.SVHN(root='./data/', split='test', transform=transform_test, download=True)
elif args.tgt == 'fgvcaircraft':
    NUM_CLS_TGT = 100
    trainset = torchvision.datasets.FGVCAircraft(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.FGVCAircraft(root='./data/', split='test', transform=transform_test, download=True)
elif args.tgt == 'gtsrb':
    NUM_CLS_TGT = 43
    trainset = torchvision.datasets.GTSRB(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.GTSRB(root='./data/', split='test', transform=transform_test, download=True)
else:
    # Fail fast instead of hitting a NameError on `trainset` further down.
    raise ValueError(f'Unsupported target dataset --tgt {args.tgt!r}')

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.bs, shuffle=True, num_workers=args.workers)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=args.bs, shuffle=False, num_workers=args.workers)



# Class count of the source task. It is only defined for the known source
# sets; other --src values leave NUM_CLS_SRC unset.
_src_num_classes = {'imagenet': 1000, 'cifar10': 10, 'cifar100': 100}
if args.src in _src_num_classes:
    NUM_CLS_SRC = _src_num_classes[args.src]

# Model
print('==> Building model..')
# Each branch instantiates the architecture and records `ckpt_src`, the path of
# the source checkpoint loaded afterwards (densenets loaded from ImageNet use
# torchvision's pretrained weights and need no ckpt_src).
if args.net in ('resnet18', 'resnet34', 'squeezenet10', 'squeezenet11', 'inceptionv3', 'googlenet'):
    # These architectures have no checkpoint/head handling in this script.
    # Exit with a message (non-zero status) instead of silently "succeeding".
    exit(f'--net {args.net} is not supported by this script')
elif args.net == 'resnet50':
    net = models.resnet50()
    ckpt_src = ckpt_src_dir + '/resnet50-v2.pth'
elif args.net == 'resnet101':
    net = models.resnet101()
    ckpt_src = ckpt_src_dir + '/resnet101-v2.pth'
elif args.net == 'resnet152':
    net = models.resnet152()
    ckpt_src = ckpt_src_dir + '/resnet152-v2.pth'
elif args.net == 'densenet169':
    net = models.densenet169(pretrained=True)
    if not args.src == 'imagenet':
        ckpt_src = ckpt_src_dir + '/densenet169.pth'
elif args.net == 'densenet121':
    net = models.densenet121(pretrained=True)
    if not args.src == 'imagenet':
        ckpt_src = ckpt_src_dir + '/densenet121.pth'
elif args.net == 'densenet201':
    net = models.densenet201(pretrained=True)
    if not args.src == 'imagenet':
        ckpt_src = ckpt_src_dir + '/densenet201.pth'
elif args.net == 'mobilenetv1':
    net = ptcv_get_model("mobilenet_w1", pretrained=False)
    ckpt_src = ckpt_src_dir + '/mobilenet_w1-0895-7e1d739f.pth'
elif args.net == 'mobilenetv2':
    net = models.mobilenet_v2()
    ckpt_src = ckpt_src_dir + '/mobilenet_v2.pth'
elif args.net == 'mobilenetv3_large':
    net = models.mobilenet_v3_large()
    ckpt_src = ckpt_src_dir + '/mobilenet_v3_large-8738ca79.pth'
elif args.net == 'mobilenetv3_small':
    net = models.mobilenet_v3_small()
    ckpt_src = ckpt_src_dir + '/mobilenet_v3_small-047dcff4.pth'
elif args.net == 'swin_b':
    net = models.swin_b()
    ckpt_src = ckpt_src_dir + '/swin_b.pth'
elif args.net == 'swin_v2_b':
    net = models.swin_v2_b()
    ckpt_src = ckpt_src_dir + '/swin_v2_b.pth'
elif args.net == 'vit_b_16':
    net = models.vit_b_16()
    ckpt_src = ckpt_src_dir + '/vit_b_16.pth'
elif args.net == 'wide_resnet101_2':
    # torchvision's constructor is `wide_resnet101_2` (there is no `_v2` variant).
    net = models.wide_resnet101_2()
    ckpt_src = ckpt_src_dir + '/wide_resnet101_2.pth'
elif args.net == 'efficientnetb0':
    net = models.efficientnet_b0()
    ckpt_src = ckpt_src_dir + '/efficientnet_b0_rwightman-3dd342df.pth'
elif args.net == 'efficientnetb1':
    net = models.efficientnet_b1()
    ckpt_src = ckpt_src_dir + '/efficientnet_b1_rwightman-533bc792.pth'
elif args.net == 'efficientnetb2':
    net = models.efficientnet_b2()
    ckpt_src = ckpt_src_dir + '/efficientnet_b2_rwightman-bcdf34b7.pth'
elif args.net == 'efficientnetb3':
    net = models.efficientnet_b3()
    ckpt_src = ckpt_src_dir + '/efficientnet_b3_rwightman-cf984f9c.pth'
elif args.net == 'vgg16':
    net = models.vgg16()
    ckpt_src = ckpt_src_dir + '/vgg16-397923af.pth'
elif args.net == 'vgg19':
    net = models.vgg19()
    ckpt_src = ckpt_src_dir + '/vgg19-dcbb9e9d.pth'
else:
    # Fail fast instead of hitting a NameError on `net` further down.
    raise ValueError(f'Unknown architecture --net {args.net!r}')


print('==> Resuming from source checkpoint..')

# Self-supervised source checkpoints override the supervised `ckpt_src` path.
# NOTE(review): any other --ssl value loads no weights at all (random init).
if args.ssl == 'moco':
    ckpt_src = '/home/yuhe.ding/code/TE/checkpoint/src/moco_v1_200ep_pretrain.pth.tar'
elif args.ssl == 'mocov2':
    ckpt_src = '/home/yuhe.ding/code/TE/checkpoint/src/moco_v2_200ep_pretrain.pth.tar'
elif args.ssl == 'simclr':
    ckpt_src = '/home/yuhe.ding/code/TE/checkpoint/src/resnet50_imagenet_bs2k_epochs200.pth.tar'

# map_location='cpu': `net` is still on the CPU here, and GPU-saved files must
# also load on CPU-only machines.
if args.ssl is None:
    if not args.src == 'imagenet':
        ckpt = torch.load(ckpt_src, map_location='cpu')
        # NOTE(review): a 2048-wide fc assumes a ResNet-50/101/152-style backbone.
        net.fc = torch.nn.Linear(2048, NUM_CLS_SRC)
        net.load_state_dict(ckpt['net'])
    else:
        # ImageNet densenets were already built with pretrained=True weights.
        if not args.net.startswith('densenet'):
            ckpt = torch.load(ckpt_src, map_location='cpu')
            net.load_state_dict(ckpt)

elif args.ssl.startswith('moco'):
    ckpt = torch.load(ckpt_src, map_location='cpu')
    # Matches the shape of a 128-d projection fc so such weights can load.
    net.fc = torch.nn.Linear(2048, 128)
    model_dict = net.state_dict()
    # Keep only the query encoder. The previous fixed `k[17:]` slice also
    # turned 'module.encoder_k.*' (same prefix length) into backbone names, so
    # the momentum encoder silently overwrote the query encoder's weights.
    moco_prefix = 'module.encoder_q.'
    new_dict = {k[len(moco_prefix):]: v for k, v in ckpt["state_dict"].items()
                if k.startswith(moco_prefix)}
    if not new_dict:
        raise RuntimeError(f'No "{moco_prefix}*" weights found in {ckpt_src}')
    model_dict.update(new_dict)
    # strict=False: projection-head keys that do not exist in `net` are ignored.
    net.load_state_dict(model_dict, strict=False)

elif args.ssl == 'simclr':
    ckpt = torch.load(ckpt_src, map_location='cpu')
    model_dict = net.state_dict()
    # NOTE(review): drops the first 8 characters of every key — presumably a
    # backbone prefix such as 'encoder.'; confirm against the checkpoint.
    new_dict = {k[8:]: v for k, v in ckpt["state_dict"].items()}
    model_dict.update(new_dict)
    net.load_state_dict(model_dict, strict=False)
# else:
# for name, module in net.named_children():
#         print(name, module)
        # for s in act_list:
        #     count += str(module).count(s)

def _set_trainable(model, head_prefix, tune_head):
    """Mark parameters trainable: only those named `head_prefix*` if tune_head, else all."""
    for name, param in model.named_parameters():
        if tune_head:
            param.requires_grad = name.startswith(head_prefix)
            if param.requires_grad:
                print(name)
        else:
            param.requires_grad = True


# Swap the source classifier for a freshly initialised NUM_CLS_TGT-way head.
# Input widths are read from the existing layer rather than hard-coded, so all
# depths/widths of a family are handled.
if args.net.startswith(('resnet', 'wide_resnet')):
    net.fc = torch.nn.Linear(net.fc.in_features, NUM_CLS_TGT)
    head_prefix = 'fc'
elif args.net.startswith('densenet'):
    net.classifier = torch.nn.Linear(net.classifier.in_features, NUM_CLS_TGT)
    head_prefix = 'classifier'
elif args.net.startswith('efficientnet'):
    net.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(net.classifier[-1].in_features, NUM_CLS_TGT),
    )
    head_prefix = 'classifier'
elif args.net == 'mobilenetv1':
    # pytorchcv MobileNet exposes its final Linear as `output`.
    net.output = torch.nn.Linear(net.output.in_features, NUM_CLS_TGT)
    head_prefix = 'output'
elif args.net.startswith('mobilenet'):
    # torchvision MobileNetV2/V3: the last entry of `classifier` is the Linear head.
    # (Previously a 'monilenet' typo meant mobilenets kept their 1000-way head.)
    net.classifier[-1] = torch.nn.Linear(net.classifier[-1].in_features, NUM_CLS_TGT)
    head_prefix = 'classifier'
elif args.net.startswith('swin'):
    net.head = torch.nn.Linear(net.head.in_features, NUM_CLS_TGT)
    head_prefix = 'head'
elif args.net.startswith('vit'):
    net.heads.head = torch.nn.Linear(net.heads.head.in_features, NUM_CLS_TGT)
    head_prefix = 'heads'
elif args.net.startswith('vgg'):
    # The whole VGG classifier stack is re-initialised, not just the last layer.
    net.classifier = nn.Sequential(
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(True),
        nn.Dropout(p=0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(p=0.5),
        nn.Linear(4096, NUM_CLS_TGT),
    )
    head_prefix = 'classifier'
else:
    raise ValueError(f'No classification-head rule for --net {args.net!r}')

_set_trainable(net, head_prefix, args.tune_head)




net = net.to(device)
if args.multigpu and device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume or args.test:
    # Load the target checkpoint, either to continue training or to evaluate (--test).
    print('==> Resuming from checkpoint..')
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(ckpt_tgt):
        raise FileNotFoundError(f'Error: no checkpoint found at {ckpt_tgt}')
    print(ckpt_tgt)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(ckpt_tgt, map_location=device)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    # The stored epoch was already completed; continue from the next one.
    start_epoch = checkpoint['epoch'] + 1

# Loss: optional label smoothing (--ls). The label-smoothing loss is only
# asked to use CUDA when training actually runs on CUDA.
if args.ls:
    criterion = classifier_loss = CrossEntropyLabelSmooth(
        num_classes=NUM_CLS_TGT, use_gpu=(device == 'cuda'))
else:
    criterion = nn.CrossEntropyLoss()
# Only parameters left trainable (e.g. the head under --tune_head) are optimised.
optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)
# Cosine decay with T_max = args.epoch scheduler steps (one step per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epoch)

def print_args(args):
    """Render the attributes of ``args`` as a text block (nothing is printed).

    The result starts with a separator line of '=' characters, followed by one
    newline-terminated ``name:value`` line per attribute, in the order of
    ``args.__dict__``.
    """
    separator = "==========================================\n"
    entries = ["{}:{}\n".format(name, value) for name, value in args.__dict__.items()]
    return separator + "".join(entries)

os.makedirs(ckpt_tgt_dir, exist_ok=True)

# For normal training runs (neither --test nor --debug), open
# <ckpt_tgt_dir>/log.txt on args.out_file and write the argument dump first.
if not (args.test or args.debug):
    log_path = osp.join(ckpt_tgt_dir, 'log.txt')
    args.out_file = open(log_path, 'w')
    args.out_file.write(print_args(args) + '\n')
    args.out_file.flush()

# Training
def train(epoch):
    """Run one training epoch over `trainloader` and checkpoint the model.

    Uses the module-level `net`, `optimizer`, `criterion`, `trainloader` and
    `device`. Unless --debug is set, it appends a summary to args.out_file and
    saves {'net', 'acc', 'epoch'} (acc = training accuracy in percent) to both
    ckpt_<epoch>.pth and `ckpt_tgt`. `ckpt_tgt` is the latest-epoch file that
    --resume / --test load.
    """
    print('\n' + args.tgt + ' on ' + args.net + ' ' + 'Epoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches = batch_idx + 1

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    if not args.debug:
        # max(..., 1) keeps an empty loader from raising NameError/ZeroDivisionError.
        avg_loss = train_loss / max(num_batches, 1)
        acc = 100. * correct / max(total, 1)
        args.out_file.write('\nEpoch: %d' % epoch + '\n' + 'Train | Loss: %.3f | Acc: %.3f%% (%d/%d)'
                            % (avg_loss, acc, correct, total) + '\n')
        args.out_file.flush()

        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(state, f'{ckpt_tgt_dir}/ckpt_{epoch}.pth')
        # Also refresh the latest checkpoint that --resume / --test read.
        torch.save(state, ckpt_tgt)

def test(epoch):
    """Evaluate `net` on `testloader`, showing running loss/accuracy.

    `epoch` is accepted for symmetry with train() and is not used. Unless
    --test or --debug is set, one summary line is appended to args.out_file.
    """
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches = batch_idx + 1

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    if not (args.test or args.debug):
        # max(..., 1) keeps an empty loader from raising NameError/ZeroDivisionError.
        args.out_file.write('Test | Loss: %.3f | Acc: %.3f%% (%d/%d)'
                            % (test_loss / max(num_batches, 1), 100. * correct / max(total, 1),
                               correct, total) + '\n')
        args.out_file.flush()


if args.test:
    # Evaluation-only run: score the loaded checkpoint once, then stop.
    test(0)
    exit()

# Train for args.epoch epochs, evaluating after each one; the LR scheduler
# advances once per epoch.
final_epoch = start_epoch + args.epoch
for ep in range(start_epoch, final_epoch):
    train(ep)
    test(ep)
    scheduler.step()
